{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "multi-task-learning-example-pytorch.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "metadata": {
        "id": "6BaMKLQvE8OO",
        "colab_type": "code",
        "outputId": "d9037787-b981-4975-d670-1ab5df35d47c",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import numpy as np\n",
        "# Seed NumPy and PyTorch so the random draws made later in the notebook are repeatable.\n",
        "np.random.seed(0)\n",
        "torch.manual_seed(0)"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<torch._C.Generator at 0x7fec76671df0>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 1
        }
      ]
    },
    {
      "metadata": {
        "id": "tu9_J0xrIHhj",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "N = 100  # number of synthetic samples\n",
        "nb_epoch = 2000  # number of passes over the data during training\n",
        "batch_size = 20  # samples per mini-batch\n",
        "nb_features = 1024  # width of the network's hidden layer\n",
        "Q = 1  # input dimensionality\n",
        "nb_output = 2  # total number of outputs\n",
        "D1 = 1  # dimensionality of the first output\n",
        "D2 = 1  # dimensionality of the second output"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "ObaYZBVpyO3Z",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# Evaluate on synthetic data"
      ]
    },
    {
      "metadata": {
        "id": "aoKriQlrIM3A",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "def gen_data(N):\n",
        "    \"\"\"Draw N samples of a synthetic two-task linear regression problem.\n",
        "\n",
        "    Both targets are noisy linear functions of the same input X:\n",
        "    task 1 uses slope 2, offset 8 and noise std 10; task 2 uses slope 3,\n",
        "    offset 3 and noise std 1. Reads the notebook-level Q, D1 and D2 and\n",
        "    draws from NumPy's global RNG (X first, then each task's noise).\n",
        "\n",
        "    Returns the tuple (X, Y1, Y2), where X has shape (N, Q).\n",
        "    \"\"\"\n",
        "    # Ground-truth parameters of the two tasks.\n",
        "    slope_1, offset_1, noise_std_1 = 2., 8., 1e1\n",
        "    slope_2, offset_2, noise_std_2 = 3, 3., 1e0\n",
        "\n",
        "    X = np.random.randn(N, Q)\n",
        "    noise_1 = np.random.randn(N, D1)\n",
        "    Y1 = X.dot(slope_1) + offset_1 + noise_std_1 * noise_1\n",
        "    noise_2 = np.random.randn(N, D2)\n",
        "    Y2 = X.dot(slope_2) + offset_2 + noise_std_2 * noise_2\n",
        "    return X, Y1, Y2"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "Db4B3XIzIPi-",
        "colab_type": "code",
        "outputId": "9ce79361-d020-4c16-e123-a5b51872f786",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 129
        }
      },
      "cell_type": "code",
      "source": [
        "import pylab\n",
        "%matplotlib inline\n",
        "\n",
        "# Draw the dataset once; X, Y1 and Y2 are reused by the training cells below.\n",
        "X, Y1, Y2 = gen_data(N)\n",
        "\n",
        "# Plot both targets against the shared input, labelled so the high-noise and\n",
        "# low-noise tasks can be told apart.\n",
        "pylab.figure(figsize=(3, 1.5))\n",
        "pylab.scatter(X[:, 0], Y1[:, 0], label='task 1')\n",
        "pylab.scatter(X[:, 0], Y2[:, 0], label='task 2')\n",
        "pylab.xlabel('x')\n",
        "pylab.ylabel('y')\n",
        "pylab.legend(fontsize='x-small')\n",
        "pylab.show()"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAM0AAABwCAYAAAC96PTCAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAHC1JREFUeJztnXlwHNd95z89PZgZgINjcFAEeEgi\nCTRlSbyVSLZkijQp0o4lsXQx4lo5bKXslHdLe7h2k0pqvavNxrtOstnK1pYrtuNSylo6ii2HK62V\n2GJ4SKREkeIFiiIbACmSIgESx1wYAtMz0937x2DAAdDdc/UMhmB/qlhFdPd0v5553/d+v9/7vfcE\nXddxcHDIH9dsF8DB4VbDEY2DQ4E4onFwKBBHNA4OBeKIxsGhQNyz+fChodGKhu4CgTpCobFKPrIk\nlKRKJKbQ6PfirREL/vyt9r52YOc7t7XVC0bHZ1U0lcbtLrzizQaqpvHa3j5O9AwRjCo0N3hZ09XG\njk3LEV35GwfZ71uqAKuBfN6hEr/xbSWaW4XX9vax58Mrk3+PRJXJv3du7iroXnYJcDaptne4Nb61\n2wglqXKiZ8jw3ImeYZSkWtD9MgIciSro3BTga3v7bChtZai2d3BEU2VEYgrBqGJ4LjQaJxIzPmdE\nPJGyVYCzgd2NiB04oqkyGv1emhu8hucC9T4a/cbnjAhF7RPgbGFnI2IXjmiqDG+NyJquNsNza7pa\nLZ14JakyGBqbbH0DDVYC9BYkwNnCzkbELpxAQBWyY9NyIG1+hEbjBOp9rOlqZfsjdzMYGpsRPRpT\nkux6u5dzl4KERhOTjvK/fG4Ndb4aRgxaaiWp4hYNI6pVRaYRyQ6MZMjViJQLYTaznCs9TtPWVs/Q\n0GglHwkUH+7NfM5f52H3uxdmRI+eeXQpP9t/gYPdA8QTM237L332Lg6f7ic4mjC8/8Y1HbywdUVJ\nZcxVdjvudzN6NrURMYqe2fkbm43TOKIpI3aFSnft6TFsaRfP9/PpYMz0c80NXkITEScjmvwe/uT3\nHjQUZLHh3HKGh/MRYiVE45hnZcSO8Rar6NHVIXPBQDoQUF9XQ3QsaXg+Ekvwk7d7OPTRtZLKmI2d\nY0zT8daIzA/UlXQPO3ACAWXCrlCpVfRIy9FPtwVqTYMKkA4GnLscKrmMGaoxPFwOHNGUCbtCpVbR\nI1cOP/7B+9r5ymNdLJ7vNzy/4s6AaRlHonGC0XheZcxQjeHhcuCIpkzYFSq1CkEvbDMWg88jsnn9\nIr76+L2ILhf/8XfWs3FNB01+DwLQ0uBj8/pF7NzSaVpGgD0ffppXGTNUY3g4w/RwfCk4Pk2ZsDNU\nahaCzkTPbh73smJJgOe3dFHndSOK6TZRdLl4YesKnts005G+b1kLB070Gz63+3wQJanmXdZqDA+X\nIzDhiKaMmFX2zPF8EV0udm7u4ukNy2ZUerPjRhg50smkZnp9xqQqxPm2653tohyBCUc0ZcSqsheD\nWfSo2KiSklSRTQIBAE1+86wBs/BvKe9cjrEiq8DE0xuWFfUcRzQVwKpSz+Y8FyvHHdKBgullytfc\nKUTIqqaxa08vJ3uGCcesTahCvq98AhPFNDYliUaSpO8Cj0zc5zvAUeDHgAgMAC/Isjw3QiY2U6k5\nIlaVLOO4G6XZ+DwiO7d0zjhejLljVQZV03j5lQ+nDNIa3bOY78vq/UoJTBQtGkmSNgL3ybL8kCRJ\nLcAJ4J+B/y3L8k8lSfpT4KvA94p9xlymnIOAkK5kP9h9mkOnrppWMivH/eGV7dR5a6YcU5Iqx+VB\nw+cZmTv5VPRdb/eYZjVk39Pq+8o2BbMpV2CilJ7mHeDIxP/DwDzgUeAbE8feBL6FI5pJMi1urddd\nFls7m3xFma/jrmoar/5SNs1j
MzJ3cpVBSaqc6B02fYeRaHxSDGbf18HuAY7Lg5OJqp9btZDHH1oy\nKcpyBCaKFo0syypwY+LPrwFvAVuzzLFBoN3qHoFAXcXn7be11Vf0eQCqqvGjN89w+KMBhsLjNNf7\nDE0GSFc+0VNDW+u8op8XT6ToPj9ieK77/Ahff7oWn+fmT//S8+uIJ1KEogqBBu+Ucxl+sPv0lHSb\n6bQ21bLsrpbJz+ZThlRUIRwzFiGkB2/fOX2Nxx9ZSnDU+PuKJ9TJZNWRqMIb714A4Pe231/Q+xVC\nyYEASZKeJC2ax4DerFM5884rvVLKbGU5T0+4HLEYaQ/U+1ATySnlLDRYMBgaYyg0bnhuODzO+Ysj\nkz3C9HuPRsaZ/g0pSZVDp65aPnPlspYpn7Uqw9BEGRr9XlpMfA5Ipwm99d5FxscTNNebXzedQ6f6\n+eKvLZ7xXbnB8P3MMGtgSw0EbAX+CNgmy3JEkqSYJEm1siyPAwsB41Gz2wirsKcR2bZ2tk8wElVo\n8ntY09nKzi1dlsGCfBzgQhzrXFG2z963YIa5Y1UGAfjlkcvs3NJl6nNk030+yMrlrew7bi3cDKVE\nxvKh6DCNJEmNwJ8BX5ZlOThxeA/w9MT/nwb+qbTi3frkqnCZ7ri53svm9YumVL7sBSUAwrEE+070\n8/IrH6Jq5oOSblGgzldjeG7FkqYZ9861WIVVekxLg5cXtkozhGaV/qPpsO9EP6/t7WP7I3fj81j3\nnqHROJvXLWLz+kW0NPhwCenn+jzG1bfcKTul9DQ7gFbg7yVJyhz7beCHkiR9HbgE/G1pxbv1sWpx\ngcm5Lqs6W6c46FY91KeDMXbt6eWFxyTD86/t7TOMSIkuOPTRNc5dDnEjbjxdwCgQYR2FajM1GXds\nWo6q6Rw4cdUwI/tEzzCfX9WBYjCBLptAvY/mBt+MQdPXD5yflZSdUgIB3we+b3BqS/HFmXtYVbhs\nuvtGUDbezPOKxBRLG/5kzzDPbVw+o3JYiU2d6Jys7puJWE03bYqJQokuF1sfWGxqVoVG46DreD2i\n4czTDNkiyB40NSrT51Z18PhDS0zvZQdORkAFyLS4+09cxWyi7Eg0zoWrEZYubMRbI9Lo99Lk95hG\nl8I3FMPKncsczId/PHKZr0zzm4pNj7Fy9m+aUeYTgzau6TAVplGZFnU0lT3Y40wNqACZFtdqZrlL\ngD//u5P88Q8Os2tPD25RYE1nq+n1zSZ2u5X/kS8HJvwNIzItfUYwuVLuvTUiq0zeY1VnC+NKinjC\n3D/b+mtLcmZITC9TuXF6mgqRT3gVpg0Abumi72rU0D8xs9vzNQdzcVweshxkLST6Zjb2IGD9vbQ0\nzO4cHDOcnqZCWEWTjDjRM0xK1dMTyNYuJOD3IggQ8HvZuHahpS+xY9NyNq9fxPxALYIA3prCf+bQ\nqGI50zLf6JuSVDlpMup/sjc9+FnsOm+zhSOaCpKpzC0NPgSgvs44LAw3xxrSdnsnqztbaJznIRRT\n6O4b5rW9faZhZ9HlYsem5ay/5w4a6mpQps2Z8XlEOlqtxzCsFhMsZC2AfDKNs7+XdDjZNyP8XioJ\nNcHQ2AgJ1TwDIV8c86yCTHdca71uXn7laM4s3Nf29rEva3ZlPsmd0/O+soknVO65M4DocpkmS66V\nzEPJhaTc5zPQave8o2xUTeXnfb+ge+gMISVMwNvEyrZ7eWr5byC6inuG09MUgF3zzDOOa32dJ6dp\nUswKL/lkIZzsHeE//Iu1PLqmY4r55vOIbFpnbf4VshZAIcvsFurQxxIx5GAfsYT5UlY/7/sF+68c\nJKiE0NEJKiH2XznIz/t+kdczjHB6mjwo59yXXOMfxUykyifsHBqNExtL8FtbV7BjUydD4XHQddry\nqLSFptzbmWmcUBMMj43wo49/wvUbg2houBCYX9fGv1n7Ddqon3Jt99AZw/ucHj7Dk8u24RE9BZ
fB\nEU0elHPuSy7TJB/zZnrSZa4sBICGeR5qvemf31sjsshkZRszChGCqqfY8tkWNq5v49JQkM4Fd9BS\nP/N5CTVBRBml0Vs/ozKrmsrPet/g1PAZIkp0yjkNnWtjg/zhwT9hW9cGtnU8hugSiSijhJSwYfmD\n8TARZZS2upaC3hsc0eSkXPPMp2M1/9+sVV/V2cLrB84b9oCrOlvZe8w8wTEcS/DyK0cL7jGzBWol\n9oSaIBgPs//TQ3w0cpZQPAw66AII5320u5fx7x/dicddMymIk4MfEU2OEvA2sqrt/km/Q9VU/vuH\n/4urMev8Xw2Nt3r2MTaW5NmuJ2j01hPwNhFUZq6D4BU9+D3FJXTeEqKp1nn05c6mzWDWquu6btgD\nqpqeM58r+3rI3WNamahNDW4iSphGVz2qrvLTnjfoCZ2f2soL6X8CgCfOAGf47v5d/OGmr/Dfjv4V\n/TcGJi8NKRH2XzlIUkuwc8Uz/LTnjZyCySbb9FrZdi/7rxyccU1cVfh/F97m2a4n8r5vhqoWTaX3\nWjQSZ7nmmReCkQkH8Mc/OGx4vVmC5CQuFaFGQU+m73Psk0s8Hl9Mva/WtIGaNFFdKkLtGMFkmH8+\nM8R5DpGovUZICeNxeUhqSTTMR/izGUhdYNfZ16cIJptD/UdAh9PDZ/O6X4Zs0+vLd2/h/YGjKOrM\n369Yv6aqRVPuefQZVE1j19s9nOgdJhxL0JIlTjvnmZfaY2abcAMjN3JmFwBTBaIJuBfLiIHrCN44\nqCK6DuNule8cPUKdsohwz1JC0eSUBiql6hzvuY578ce4266CeLMXuyYAE8VQtMJy3vSacU4NGjvq\nGQ4NHLE8b0Szr4lGbzogEEuOmY7NFOvXVK1oKuVL5LMaSqnRn3L0mLmXjNWmCERXfOipGkR/VjKj\nW51McYkkI0RcEZLNI+ix+6d8B5vXLSLadIKaBaWl5kxH0ETGBePZndm4cOXdewHc33rvZO9h5dcE\nvDfFVQhVK5py+BLxRGrGTmK79vTmtRpKKYNvdveYSlI1nX+fwb1Ypqb90uTfgi8O5F7Q3N12DbF5\nEHWog9Snn+F4zyBqezc1d9grGACXS6De3UAkFbW8zkwwC+rmM7+2jU9HrxJORGj2NfHrS1azreOx\nyWs8ogfveAe4DIIB4x1zK+Rspy+Raem7z48wFBqfbOm3P3I3J3vMV0MJZs0tKWU3M7t7zBkNyhQT\nTAR3ArHZfBEMKwQBBLeGq/0Krvowo7FmDl+/nMeKD4WjCSlWtHbywbVjltc1ewPc27KCj4PnGImH\naPI0cH/rvTzb9QSiS5wSql64oGXG+grhnqUkG2OIgUEEzzh6ohY1NJ9wZBnKhvzXqs5QtaKx05cw\na+nH4inCFkmJjX4P/roadu3pKdq0srPHzAhXJUlDIEkkAu5FvWkTzKOgJ7zoqhtBTCB4jGdmFoLo\nj0Ft+RY/ERB4cuk2PC4v7/UfRjXpUVa2pQViNo7jET2mfkkkphCKJtGj95C62jWlcQkLiaIslqoV\nDdgzkmzV0p+7FLIcBFzT2crudz/J27QqV/RN1TR27T3HsYu9jDX2IM6LIixX8GkgZLUdgldh0iu3\nC1f+vkSh6OgktBS/uWI7T3V+iSvRAX51eR9XRvsnza37W9N5YmAtDjOmfP+aiK7cFEix0c+qFo0d\niXxWLX04pvDQvQsM1/NaPN/P048u59t/84HhZ7NNKytHP1ePKbhUhsbC+Fx1jI/rU94xPUAY4fvv\nv8k1XYa7dWqyzaRKDFmVaJZ5XV7TqFpzliPuET0sDdzJNwK/Y5kZUPDzy7DKZlWLJkMpey3maumf\n39JFrc/NiZ5hgqNxmuZ5Wd3Vys7NnYxE4nmZVrkc/UzPeLz3GuFEiIZ5HlbftRj34rP8l8OvE4yH\nIFlLcmQ+DZGVrOwKoLd/RE/oPOFEBFxlcSnyQ8fw4QICus
U05QwPtq+jL3LRcHByZdt9hqIopkex\nwu5VNm8J0ZRCrpamzus27c3yzfs63juA4B276YgDuNTJQcM6jwdx8VlU91G8WgIFOIoL7eqE6SMA\nnnFq2i8Rqx/hsDaOcL169qd8cMED9Ib7CMbD1Oh1qOFmbpxfjv/uS4jNgySEGwS8jfjctdxIjBFN\nRmeYVj/teYPTw2cIJ6K0+AJTzhVLvsEZu6ce2L4luiRJfwk8SLqNekmW5aNm11ZqS/Ts6NlweNxy\nH/rp/HjPx+w/fWGqIGBiktQy/s+Z3bz/6UkEr4KueFCjLaCJiE3DCN44TZ5G5nnqTEe9qx31hp+X\nH/4WTQ1ufryvm0PHQ1O+B1wqn1sb4IWNK/GIHlPTSkmqDEdHoSZB27ymkswuK3N4wR2Nt9aW6JIk\nbQA6J3YSuAf4EfCQnc8ohkxL8/WnayeXQ83V0mQmL/XUnsG3MgSpWlLhAP7gatZ1pldI+Vnvm3ww\n9AEuX/ozgi+ByzdVHJFkhEgyUq5Xs5Up7acO2lg98658nka/F12Dcz3KVMEAaCLnehLoG0QQZ5pW\n5RjYtTKHX3p+XVH3LAS7zbMvALsBZFk+K0lSQJKkBlmWrUevKoTP487LN0qoCf5O/oeb4wcCUDOO\nu20c7Y5hXAvWkVAX88HAh+UtcCloFDzFUA03k7ws4fKk0MbrIeXh8+vb8daIXBkctVy03Ww8qxwD\nu1bjXvFEquB7FordolkAZI9UDU0cMxRNpXcNUFIJ1No4AV8jXrdnxrn+6CBv9rzN2es9jMSN52Ek\ntATv9r9PX/RCwblWlSQ1uIRHVi/g8LWjOR12L37ig60kL3YBLjQFar1uvvDgYn73Nz7D3751lvc/\nMjcvW5tquWtxgF2/lCd3RmhrqmX9PXdwqs948Nho94J8GBi+YbqDQGg0Tiiq0F7mnSHKHQiwDPpU\nateAjKl1JniW4bHglHniAK/3vsnha8cMM2HNGIhdN40szSZaUkQdXkRDZBW//9AWao6IvHP1fcPr\nhPAilP4ljCd9M8yuBz8zn6cevpvvvX4q53JQK5e18MPdp6dcNxga5633Lpp+ZvruBfmiJlXTHQQC\n9T4CDV47fRrD43aLpp90z5Khg/Q2grNKZp54hsw88QwHrr5X1H3LrRldT6e15LxOAz3hQx1tJnXp\nHtBqWLt+Pj6Pm2c6n0DAxeErp1CIoSk+iLWgXFwBmvlqON3ng4yOJSzXGmiu97JWauNLDy7h5VeM\nTVWXgOE0hWIHFnNFQ30ed95baRSL3aL5FfCfgb+WJGkt0C/LcuU3hMnCap5499BHeY01mGGHYHQN\nBBPfI3W9A0QQ64MInjgeoRZR9xJX4+hiHCFVywL3XSzQP4N8SSEcTdEybQxCdIk8Jz3J40u38uq+\n03z8yQ0io7lH+UOjca4MxkzHqQQB/tUzKzl0eoCXX/nQdPlcs3k9paxpNjnuJQ8RGlUITIi3Utuu\n2yoaWZbfkyTpmCRJ75F2Rb9p5/2LwXKeuMnxfBFStSRG2nC3XkFw559uouugJ9LhaX9wFQ9sjHDk\n2jESWrrieV0e6pWlxEKdRGJJGhrc3NNZh1ur5cCx61MSND/RRJatD/BfX7Qeg/iHA5d5/3j+7Veg\n3sei+X7Tcarmeh/vnOrPuWeMS4CHV7Vz5kLItu37MmR64Xx6Yzux3aeRZfkP7L5nKWkVVvMpmr1N\n6OiElOJCwq5oO6nLXaSudOJechaxYeRm4mTKg8t3w1BMqaEOUpfvBU1k/fpFPNf1KKkrEicuXiQa\nS1LrbULqbGf7i0uJjSVo9HtJJFW+/aOJCVnTcqgyKT1W264XsrEUwOrOlsklpoxMoZXLmuk2cfKz\n0XT40q/fyfNf6LJtyvpcCznbih0LvVnNE1/Zdh+A4bnpdMxbwFhqjIgySrOvieX1XRw40pg+qdWQ\nuriS1PQUfVcS/7IexK
YgSWEMjz4PNTgf5dJyWurrJlvc1/b2se/Da4AP8BGMq5OVIHP+2LkhUxMo\nV7Z0MTsJpCb25TBLQdm4ZiH7T+Set988sVJnKalQ2czFkLOtWDnwhSyIkImSfRw8y9BYcEaKh65r\n6ehZKr0usa6BoAvg0tETtSz0LOUPHvgKqq5O9ni6JnL63cNTTZeJHsAlwMOr29n6wBKaGzYjuKZ+\nLrvFzVUJVE3PaQLlcqrzWdJpOu+eGsA1sSSuUQqKklTzuqfVSp3FkGuqRSiqlL1SV61o7FzoTXSJ\nPNv1BA2BZzl/tX+GmfectJ0v3rmVb7+6j0gsiZ5It4iZXiPin0dqA3hrska7RUxNlw2rO3hh64rs\nEkz5XHaLa1UJgtG45SS5DLmc6mJ2EtB02Hf8KqJLYOfmrhk9Ra57tjTY57tkkysfMNDgZTSSewp1\nKVStaMqx0JvXbZ49Oz6uExn2oeObPJbxG8zMHzuyZ60qQaPfYzlJrsnvYf2K+Xk9z6isK5e3cKp3\niOCo+aLgVrNLze65ed0imht8ZVluay6GnG3D0oH3FbcgguXzLCqvp0bEb7DCf67s2XyycC0rQWcr\n3edHjFtVv5f/9NUHqK/Lt7c1LqvoEix7ICt/qZwLl1thd6p/oVStaKwc+OzVRuzCqvLGEyq73/3E\nNFdquulSaJKiVSUQRePV/9etaMtbMFZl3bFpOaqqceBkf9GDkHY5+fkyW2LNULWigZsO/OnhMwTj\n4RkOvN1sf2QpB7v7DbezK2QRjEKTFK0qQblbVdHlSvtfgmAYcKjWjZWg8mLNUNWiyTjwTy7bZtv0\nVytiYwkUk/0f810Eo5TVZ4wqQaVa1Z2bOxFdwqyZPLcSVS2aDHZPfzXDjkUwyrX2c7lb1dk2eW4l\nnE2dsihkAyIzCtnwqBqp9E7JtyKOaKZR6v6PdgjPobq5JcyzSmKHmTLbIVGH8uKIxoRSfAjHP5jb\nOKIpI7MVEnUoL45P4+BQII5oHBwKxBGNg0OBOKLJgZJUGQyNoSSrZ5lYh9nFCQSYUOlNch1uHRzR\nmFCpTXIdbj2cJtOAXEmXjql2e+OIxoB8ki4dbl+KMs8kSXIDfwMsm7jHt2RZPihJ0irge6QXn+yW\nZfn3bStpBbFzk1yHuUexPc0LwA1Zlh8Gvgb8j4nj/5P0njSfAxolSfqiDWWsOE7SpYMVxQYCXgV+\nMvH/IaBFkiQPcHfWJk5vApuBfyytiLODk3TpYEZRopFlOQlk9tz+18AuoBXIXgVjEGi3uo/ZTlPl\nxGwleCNeen4dj/+7/1sHtI9E4wMvPb+uMtsc2Egh7ztXKPc75xSNJEkvAi9OO/xtWZZ/KUnSN4G1\nwOPAdHumyjahKI43/+LJMeD8bJfDoXrIKRpZln8I/HD6cUmSvkZaLNtlWU5KkjQEZM9JXkh66w0H\nhzlFUYEASZKWAt8AnpJlOQ6TJts5SZIenrjsKeCfbCmlg0MVUdTuzpIk/Snwm8DlrMOPAcuBvyYt\nxg9kWf63dhTSwaGasH1LdAeHuY6TEeDgUCCOaBwcCuS2ynI2S/+Z3VKVB0mS/hJ4kHRK00tZg85z\nEkmSvgs8Qvp3/Y4syz8v17Nut57GLP1nTiFJ0gagU5blh0i/51/NcpHKiiRJG4H7Jt53G+l0rrJx\nu4nmVSAT0Zs+rjSX+AKwG0CW5bNAQJKkhtktUll5B3h24v9hYJ4kSWVLELytzDOT9J+5yALgWNbf\nQxPHorNTnPIiy7IK3Jj482vAWxPHysKcFU0B6T+3A3MipSkXkiQ9SVo0j5XzOXNWNPmm/1S8YJWh\nn3TPkqEDGJilslQESZK2An8EbJNlubg97vPktvJpjNJ/5ii/Ap4BkCRpLdAvy3K5t6KcNSRJagT+\nDPiyLMvBcj9vzvY0JrxI2vl/S5KkzLHHZFk236n1FkSW5fckSTomSdJ7gAZ8c7bLVGZ2
kJ6a8vdZ\nv+tvybJ82fwjxeOk0Tg4FMhtZZ45ONiBIxoHhwJxROPgUCCOaBwcCsQRjYNDgTiicXAoEEc0Dg4F\n8v8BVsZonU7svPQAAAAASUVORK5CYII=\n",
            "text/plain": [
              "<Figure size 216x108 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "metadata": {
        "id": "_E7LhkkoyR72",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# Example Model"
      ]
    },
    {
      "metadata": {
        "id": "oFKJ9_mfIUlQ",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "class Net(nn.Module):\n",
        "    \"\"\"Two-headed MLP: one shared hidden layer feeding two linear output heads.\"\"\"\n",
        "\n",
        "    def __init__(self, input_size, hidden_size, output1_size, output2_size):\n",
        "        super().__init__()\n",
        "        # Shared trunk.\n",
        "        self.fc1 = nn.Linear(input_size, hidden_size)\n",
        "        self.relu = nn.ReLU()\n",
        "        # One linear head per task, both reading the same hidden activation.\n",
        "        self.fc2 = nn.Linear(hidden_size, output1_size)\n",
        "        self.fc3 = nn.Linear(hidden_size, output2_size)\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Return the (head 1, head 2) predictions for the input batch x.\"\"\"\n",
        "        hidden = self.relu(self.fc1(x))\n",
        "        return self.fc2(hidden), self.fc3(hidden)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "r-UlDkdlK21o",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Q inputs -> nb_features hidden units -> two output heads of size D1 and D2.\n",
        "model = Net(Q, nb_features, D1, D2)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "EyR5ree1yWDk",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Define task-dependent log variances"
      ]
    },
    {
      "metadata": {
        "id": "ih0Ocq-NFuX2",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# One learnable log-variance per task, initialised to 0 (i.e. variance 1).\n",
        "# requires_grad=True makes autograd track them so they can be trained.\n",
        "log_var_a = torch.zeros((1,), requires_grad=True)\n",
        "log_var_b = torch.zeros((1,), requires_grad=True)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "Q_aXEqn84eGR",
        "colab_type": "code",
        "outputId": "b43872e7-3303-4123-9e28-f5ff9746cae6",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "cell_type": "code",
      "source": [
        "# Initialized standard deviations (ground truth is 10 and 1):\n",
        "# std = variance ** 0.5 = exp(log_var) ** 0.5\n",
        "std_1, std_2 = (torch.exp(log_var) ** 0.5 for log_var in (log_var_a, log_var_b))\n",
        "print([std_1.item(), std_2.item()])"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[1.0, 1.0]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "mcKDv3htMqXJ",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Everything the optimizer should update: the network weights followed by\n",
        "# the two task-dependent log variances.\n",
        "params = list(model.parameters()) + [log_var_a, log_var_b]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "HfBN-qs2Q_Wj",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Adam with its default hyperparameters updates everything in `params`.\n",
        "# Alternative kept for reference: optim.SGD(params, lr=0.001, momentum=0.9)\n",
        "optimizer = optim.Adam(params)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "MGPCPAXryqT9",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Define loss criterion"
      ]
    },
    {
      "metadata": {
        "id": "5rHlOTALGXug",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "def criterion(y_pred, y_true, log_vars):\n",
        "    \"\"\"Uncertainty-weighted multi-task squared-error loss.\n",
        "\n",
        "    For every task i the squared error between y_pred[i] and y_true[i] is\n",
        "    scaled by exp(-log_vars[i]) and log_vars[i] is added as a penalty. The\n",
        "    result is summed over the last axis, accumulated across tasks and\n",
        "    finally averaged over the remaining elements.\n",
        "    \"\"\"\n",
        "    total = 0\n",
        "    for i, pred in enumerate(y_pred):\n",
        "        log_var = log_vars[i]\n",
        "        sq_err = (pred - y_true[i]) ** 2.\n",
        "        # exp(-log_var) is the task's precision (1 / variance).\n",
        "        weighted = torch.exp(-log_var) * sq_err + log_var\n",
        "        total = total + torch.sum(weighted, -1)\n",
        "    return torch.mean(total)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "9Mk7ma7my397",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Train the network"
      ]
    },
    {
      "metadata": {
        "id": "_pQ3cDYGVtPv",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# convert data into torch from numpy array\n",
        "X = X.astype('float32')\n",
        "Y1 = Y1.astype('float32')\n",
        "Y2 = Y2.astype('float32')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "mFm-Jtkm2cGk",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "def shuffle_data(X,Y1,Y2):\n",
        "    \"\"\"Return X, Y1 and Y2 reordered by one shared random row permutation.\n",
        "\n",
        "    The permutation is drawn from NumPy's global RNG, so rows that belonged\n",
        "    together before the call still line up afterwards.\n",
        "    \"\"\"\n",
        "    order = np.arange(X.shape[0])\n",
        "    np.random.shuffle(order)  # shuffles the index vector in place\n",
        "    return tuple(arr[order] for arr in (X, Y1, Y2))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "2RJ0U_ss0uT1",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "## Train Network\n",
        "# Mini-batch training of the network weights together with both log variances.\n",
        "n_batches = N // batch_size  # a trailing partial batch (if any) is dropped\n",
        "loss_history = np.zeros(nb_epoch)\n",
        "\n",
        "for i in range(nb_epoch):\n",
        "\n",
        "    epoch_loss = 0\n",
        "\n",
        "    # Reshuffle every epoch so the mini-batches differ between epochs.\n",
        "    X, Y1, Y2 = shuffle_data(X, Y1, Y2)\n",
        "\n",
        "    for j in range(n_batches):\n",
        "\n",
        "        optimizer.zero_grad()\n",
        "\n",
        "        # Rows of the current mini-batch, shared by the input and both targets.\n",
        "        batch = slice(j * batch_size, (j + 1) * batch_size)\n",
        "        inp = torch.from_numpy(X[batch])\n",
        "        target1 = torch.from_numpy(Y1[batch])\n",
        "        target2 = torch.from_numpy(Y2[batch])\n",
        "\n",
        "        out = model(inp)\n",
        "\n",
        "        loss = criterion(out, [target1, target2], [log_var_a, log_var_b])\n",
        "\n",
        "        epoch_loss += loss.item()\n",
        "\n",
        "        loss.backward()\n",
        "\n",
        "        optimizer.step()\n",
        "\n",
        "    # Mean of the per-batch losses. Dividing by the number of batches actually\n",
        "    # run stays correct when N is not a multiple of batch_size; max() avoids a\n",
        "    # division by zero when N < batch_size.\n",
        "    loss_history[i] = epoch_loss / max(n_batches, 1)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "ZoiJCMsb1BMt",
        "colab_type": "code",
        "outputId": "270cbd11-8e7f-4426-acfa-3bf789a1fc8f",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 282
        }
      },
      "cell_type": "code",
      "source": [
        "# plot loss history (one point per epoch)\n",
        "\n",
        "pylab.plot(loss_history)\n",
        "pylab.xlabel('epoch')\n",
        "pylab.ylabel('mean batch loss')\n",
        "pylab.show()"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[<matplotlib.lines.Line2D at 0x7fec2dfdf6a0>]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 15
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXUAAAD4CAYAAAATpHZ6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xl4XPV97/H3mU27Jdka4zVm9Zdg\ndtthCRQTliQ0S28hTRPCzUKfpA1Jk6akJTe3pJC0uReacHsJT1IKgYQkt2FpiskCAROWAMFLcMCA\nf17YvGLJlmVLlqWZ0dw/zpE8EiN5NJrRLPq8HvzM6JwzM59nNHzO0e+cOcdLp9OIiEh1CJU6gIiI\nFI5KXUSkiqjURUSqiEpdRKSKqNRFRKpIpJQv3t6+P+9Db1pb6+nsPFDIOAWhXONTrrmgfLMp1/hU\nY654vMkbbV7FbqlHIuFSR8hKucanXHNB+WZTrvGZarkqttRFROStVOoiIlVEpS4iUkVU6iIiVUSl\nLiJSRVTqIiJVJKfj1M3sBuDcYPlvAquAu4AwsAO4wjnXZ2aXA18EBoBbnXO3FyW1iIhkddgtdTM7\nHzjROXcW8B7g/wDXA7c4584FNgGfMrMG4FrgQmAZ8DdmNr0Yobfu6uZHD77MgE4bLCIyTC7DL08A\nHwru7wUa8Et7eTDtAfwiPwNY5Zzrcs71Ak8B7yxo2sDjf9jOTx/ewM7d5fctMRGRUjrs8ItzLgX0\nBD9eCfwSeLdzri+YtguYDcwC2jMeOjh9VK2t9Xl9q6q2NgpAc0s98XjTuB9fbOWYCZQrH+WaTbnG\nZyrlyvncL2b2QfxSvxjYmDFrtHMQjHpugkH5nvfgYG8CgD17emiIHPZlJlU83kR7+/5Sx3gL5Rq/\ncs2mXONTjbnGWhnkdPSLmb0b+CrwXudcF9BtZnXB7LnA9uDfrIyHDU4vOC/ocV2KT0RkuFx2lDYD\nNwLvc87tCSY/Alwa3L8UeBB4FlhqZi1m1og/nv5k4SMz9DeAOl1EZLhchl8+DLQBd5vZ4LSPA7eZ\n2WeA14EfOOcSZnYN8BCQBq4LtuoLLuSV15CLiEi5yGVH6a3ArVlmXZRl2XuBewuQKyc6pFFEZLiK\n/EapNtRFRLKr0FL3W10b6iIiw1VmqQe3OvpFRGS4yiz1wS31EucQESk3FVrq/q221EVEhqvIUh+k\nThcRGa4iS93T4S8iIllVZqkHtxp+EREZrjJLXacJEBHJqkJLffA4dbW6iEimyiz14FaVLiIyXGWW\n+uDwS2ljiIiUnYosdTT8IiKSVUWWekjjLyIiWVVkqQ8aUKmLiAxTkaV+6MtHanURkUw5XXjazE4E\n7gducs59x8zuAeLB7OnA74B/Bl4A1gTT251zHypwXkDHqYuIjOawpW5mDcDNwIrBaZllbWbfB247\nNMstK3DGt/DQ+dRFRLLJZfilD7gE2D5yhvkXLW1xzq0sdLCxHDqkUa0uIpIpl2uUJoFkxkWnM30B\nfyt+0CwzuxeYA9zinPvxWM/d2lpPJBIeR1xfY2MtAE1NdcTjTeN+fLGVYyZQrnyUazblGp+plCun\nMfVszCwGnOOc+2wwaTfwD8CPgGZgpZk96pzbMdpzdHYeyOu1e3r6AOjq6qW9fX9ez1Es8XhT2WUC\n5cpHuWZTrvGpxlxjrQzyLnXgPGBo2MU5tx+4I/ixw8xWA8cDo5Z6vg6deVfDLyIimSZySONS4A+D\nP5jZ+Wb27eB+A3AqsGFi8bI7dOrdYjy7iEjlyuXol8XAt4AjgYSZXQb8KTAb2Jyx6JPAx83sGSAM\nfNM5t63gidE1SkVERpPLjtI1wLIssz4/Yrkk8ImCpDoMXaNURCS7yvxGaXCrThcRGa4yS31o+EWt\nLiKSqSJLHZ0mQEQkq4osdR3R
KCKSXWWWejD8MqBNdRGRYSqz1EsdQESkTFVmqXs6S6OISDYVWur+\nrY5TFxEZriJLfZAqXURkuIos9ZCnUXURkWwqstQH95Tq6BcRkeEqstR13WkRkewqs9SHrlGqVhcR\nyVSZpT50jVIREclUoaWu49RFRLKpzFIPbjX8IiIyXGWWuoZfRESyyunC02Z2InA/cJNz7jtmdiew\nGNgdLHKjc+4XZnY58EVgALjVOXd7ETIDGn4REckml2uUNgA3AytGzPqKc+7nI5a7FngH0A+sMrOf\nOef2FDAvkHlIo1pdRCRTLsMvfcAlwPbDLHcGsMo51+Wc6wWeAt45wXxZafhFRCS7XC48nQSSZjZy\n1ufM7EvALuBzwCygPWP+LmD2WM/d2lpPJBIeV2CAll09ANTX1xCPN4378cVWjplAufJRrtmUa3ym\nUq6cxtSzuAvY7Zxba2bXAP8IPD1imcOeoKWz80BeL75vXy8A3T0HaW/fn9dzFEs83lR2mUC58lGu\n2ZRrfKox11grg7xK3TmXOb6+HPgucC/+1vqgucDv8nn+w/F0jVIRkazyOqTRzO4zs6ODH5cB64Bn\ngaVm1mJmjfjj6U8WJOVb6DQBIiLZ5HL0y2LgW8CRQMLMLsM/GuanZnYA6AY+6ZzrDYZiHsLfh3md\nc66rGKFD2lIXEckqlx2la/C3xke6L8uy9+IPwxSVF9KFp0VEsqnIb5SGg0H1gQGVuohIpoos9ZC2\n1EVEsqrMUh/aUi9xEBGRMlORpe4FqbWlLiIyXEWWekhj6iIiWVV2qWtLXURkmMos9WBHaVpj6iIi\nw1RmqQdfPtKWuojIcJVZ6jqkUUQkq8osde0oFRHJqiJLXacJEBHJriJLfWhMXTtKRUSGqdBS16l3\nRUSyqcxS1/CLiEhWlVnq2lEqIpJVZZb60JZ6iYOIiJSZnK5RamYnAvcDNznnvmNm84E7gCiQAD7m\nnNtpZgngqYyHXuCcSxU69KEdpWp1EZFMuVzOrgH/8nWZF5v+BnCrc+5uM7sK+BLwd0CXc25ZMYJm\n8nTuFxGRrHIZfukDLgG2Z0z7LIcuZ9cOzChwrjFpR6mISHa5XKM0CSTNLHNaD4CZhYGrgOuDWbVm\n9hNgAXCfc+7bYz13a2s9kUh43KETSX9EJxIJE483jfvxxVaOmUC58lGu2ZRrfKZSrpzG1LMJCv0u\n4FHn3ODQzNXAj4A08ISZPeGcWz3ac3R2HsjrtVPBt476+pK0t+/P6zmKJR5vKrtMoFz5KNdsyjU+\n1ZhrrJVB3qWOv6N0o3PuusEJzrnvDd43sxXAScCopZ4vHdIoIpJdXqVuZpcD/c65r2VMM+BrwOVA\nGHgncG8hQo7keR6epzF1EZGRcjn6ZTHwLeBIIGFmlwEzgYNm9liw2EvOuc+a2RZgJTAALHfOrSxK\naiAc8rSlLiIyQi47StcAy3J5Mufc3080UK6ikRCJlM7oJSKSqSK/UQoQi4ZJJFXqIiKZKrrU+xMq\ndRGRTJVb6pHQ0PHqIiLiq9xSj4Y1pi4iMkJFl7qGX0REhqvcUo+ESQ2kdVijiEiGyi31qB+9X+Pq\nIiJDKrbUa2L+icD6NAQjIjKkYku9ubEGgH09/SVOIiJSPiq21Gc01wKwt7uvxElERMpH5Zb6NL/U\nO/er1EVEBlVsqU9vrgO0pS4ikqliS31wS31vt8bURUQGVW6pD46pa/hFRGRIxZZ6Q12UaCREp4Zf\nRESGVGype55HW3MtuzoPkNYVkEREgAoudYC58UZ6+1I6AkZEJJDTNUrN7ETgfuAm59x3zGw+cBf+\ntUh3AFc45/qCa5d+Ef9ydrc6524vUm4A5rU1sBrY2t7D9GDHqYjIVHbYLXUzawBuBlZkTL4euMU5\ndy6wCfhUsNy1wIX4l7/7GzObXvDEGebGGwDY1tFdzJcREakYuQy/9AGXANszpi0Dlgf3H8Av8j
OA\nVc65LudcL/AU8M7CRX2refFGAN54U6UuIgK5XXg6CSTNLHNyg3NucCB7FzAbmAW0ZywzOH1Ura31\nRCLhcQXOtGjhTJobY2ze1kVbWyOe5+X9XIUUjzeVOkJWyjV+5ZpNucZnKuXKaUz9MEZr0sM2bGfn\ngbxfNB5voqOjm2PnNLNmQzvrN7XT1lKX9/MVSjzeRHv7/lLHeAvlGr9yzaZc41ONucZaGeR79Eu3\nmQ026Fz8oZnt+FvrjJheVAvntwDgtuwt9kuJiJS9fEv9EeDS4P6lwIPAs8BSM2sxs0b88fQnJx5x\nbIOlvnGrSl1E5LDDL2a2GPgWcCSQMLPLgMuBO83sM8DrwA+ccwkzuwZ4CEgD1znnuoqWPDB/ZiN1\nNWFeeq2TdDpdNuPqIiKlkMuO0jX4R7uMdFGWZe8F7p14rNyFQh4nH9PGsy+9yRtvdrNgVnnuEBER\nmQwV/Y3SQYsXxgFYs2FXiZOIiJRWVZT6SUfPIBYJsca1H35hEZEqVhWlXhMLs+io6ezYfYCde/I/\nTFJEpNJVRakDnHacPwSzer2GYERk6qqaUj99YRuRcIhnXtypU/GKyJRVNaVeXxvltOPa2LH7AC++\nuqfUcURESqJqSh3g4qXzAfjtCztKnEREpDSqqtSPnjONWdPr+f2GDvb16ILUIjL1VFWpe57HeafO\nIZka4Ol1O0sdR0Rk0lVVqQOcfeIsPA+efelN7TAVkSmn6kq9qT7Gqce28fqb+9m8bV+p44iITKqq\nK3WAC5f4O0wfXr2lxElERCZXVZb68W9rYV68kTWunT37DpY6jojIpKnKUvc8j4uWzGMgnebR328r\ndRwRkUlTlaUOcOaiI2isi/L42m30JVKljiMiMimqttSjkTDLTptLz8Ekjz+nrXURmRryuvC0mV0J\nXJExaQmwGmgAeoJpfxtcYKNkLl46n0dWb+FXz77BuxbPIxKu2nWYiAiQZ6k7524Hbgcws/OAPwMW\nAZ90zq0rXLyJaayL8kenzOHXq7bwzLqdnHvKnFJHEhEpqkJsul4LfL0Az1MUFy+dTyQcYvlTr5FM\nDZQ6johIUXkT+dalmS0FrnLOfcLMHgP2AG3Ay8AXnXO9Yz0+mUylI5Fw3q+fq3+//wWWP/EKf3Xp\nyVxy9lFFfz0RkSLzRpuR1/BLhr8A7gzu/yvwvHNus5l9F7gK+JexHtzZmf9ViuLxJtrb9+e07Pmn\nzOHBZ17j/z20nlOObCUWLd6KZDy5JpNyjV+5ZlOu8anGXPF406jzJjr8sgx4GsA59zPn3OZg+gPA\nSRN87oJpbohxweJ57O3u57G120sdR0SkaPIudTObA3Q75/rNzDOzR8ysJZi9DCibHaYA7z1jAbWx\nML985jX6+nXcuohUp4lsqc8GdgE459LArcAKM3sCmA/cMvF4hdNYF+XipfPZdyDBit9vLXUcEZGi\nyHtMPTgG/b0ZP98N3F2IUMVy8dK3sWLNVn71u9dZdupc6msnuktBRKS8TKlv49TXRnjPGW+j52CS\n5U+9Wuo4IiIFN6VKHeCiJfOZ2VLHI6u3smVXd6njiIgU1JQr9Vg0zEcvWshAOs1dDzkGdHUkEaki\nU67UAU4+ZgaLLc6mbV089cKOUscRESmYKVnqAB+54DhqomHu+c1munsTpY4jIlIQU7bUp0+r5YPn\nHEV3b4L/fHzz4R8gIlIBpmypA1y4ZB5z2xp4fO12Xt2hi1SLSOWb0qUeCYe4/KKFpIFbl7/Iwf5k\nqSOJiEzIlC51gOMXtPKed7yNNzt7+cnDG0sdR0RkQqZ8qQP86XlHs2BWE799YQfPvvRmqeOIiORN\npY4/DPOXH1hETTTMDx9aT/veMU8DLyJStlTqgSOm1/OxixfS25fi1gdeJDWgqySJSOVRqWc4+8RZ\nnHHCEWzeto/7f/taqeOIiIybSj2D53lccbHR1lzLL55+Df
dGZ6kjiYiMi0p9hPraCJ/5wCI8z+PW\nB17St01FpKKo1LM4Zm4zf3LuUXTu79P4uohUlLyuEmFmy4B7gBeDSS8ANwB3AWFgB3CFc66vABlL\n4pIzF7BpWxfPb97NTx/dxEcvXFjqSCIihzWRLfXHnXPLgn+fB64HbnHOnQtsAj5VkIQlEgp5fOYD\ni5jT1sAjq7eyYo0ugSci5a+Qwy/LgOXB/QeACwv43CVRVxPhry87mca6KD9+eIO+mCQiZW8ipX6C\nmS03s9+a2UVAQ8Zwyy78C1NXvJktdXzpw6dQGwtz289f4uXX9pQ6kojIqLx0Hlf+MbO5wDn4F5o+\nGvgN0Oicmx7MPxb4oXPu7LGeJ5lMpSOR8LhfvxTWbtjFdbf9jmgkzNc/cxa2YHqpI4nI1OWNOiOf\nUh/JzFYCS4F651yvmZ0HfN45d9lYj2tv35/3i8fjTbS378/34XlZvX4X37v/RWLREF/80CksnN9S\nFrlyoVzjV67ZlGt8qjFXPN40aqnnNfxiZpeb2dXB/VnAEcAdwKXBIpcCD+bz3OVsyfEz+cwHF5FI\nDvDtu9fyooZiRKTM5Dumvhw4z8yeBO4H/gr4KvDxYNp04AeFiVhelh4/k6v+20kMDKT513ueZ+2m\njlJHEhEZktdx6s65/cD7s8y6aGJxKsOpx7XxhctO4eb7nueW/3yBT39gEUuPn1nqWCIi+kZpvhYd\nNZ0vffhUopEQ37t/HU+v21HqSCIiKvWJWDi/hav//DRqYxFu+/nL3Hzf8yRTOqWAiJSOSn2Cjp4z\nja987HSmNcR4bmMHf3PT4ySSqVLHEpEpSqVeAPPijfzDf18CwGs79nHDT55jz76DJU4lIlORSr1A\nZjTX8m9XL+O80+axefs+rr9zlb59KiKTTqVeQNFIiL+9/HQ+cuFx9BxM8i//sZb7Ht+sU/eKyKRR\nqReY53lctGQ+13zsdFqn1fCLZ17nW/+xVsMxIjIpVOpFcsycZq79xFJOOWYG69/Yy9e+v5KVL79J\nIU7LICIyGpV6EU2rj/HXl53MFe82EskBvnf/i3znP1+gc3/FXjtERMpcXt8oldx5nsf5p83lhAWt\n/ODB9Ty3sYP1b+zlw+86lnNOmk0oNOp5eURExk1b6pPkiOn1XP2R07ji3UY6nebOX63n6z9czes7\ny+/scSJSuVTqkygUbLV/4y/O4MxFR/D6zv1cd+cq/v2Bl+jq6S91PBGpAhp+KYHp02r59PsXcc5J\ns/nxwxt45sWdrN3UzvvPPorzT5tLTawyLhwiIuVHW+oldMKR07n+ynfwkQuPI+R53P2bTXz5u0+z\nYs1WnUNGRPKiLfUSC4dCXLRkPmctmsWDz77BI2u28OOHN/Dwqi388dkLOPOEWUQjWveKSG7UFmWi\nsS7KZcuO4Ya/PJvzT5/L7n0HueOX6/nqv/+OFWu2crA/WeqIIlIBtKVeZqY1xLjiYuOPz1zAgyvf\n4LHntvHjhzfwsyde4ZyTZ/OuxfOY2VJX6pgiUqbyLnUzuwE4N3iObwIfABYDu4NFbnTO/WLCCaeo\n6dNq+eiFC7nkzAU89tw2Hlu7nV+v2sLDq7aw+PiZXHD6XBbOb8HzdJy7iBySV6mb2fnAic65s8xs\nBvAc8CjwFefczwsZcKpraazhT849mvedfSSr1+/iwZVvsHr9Llav38WCWU2cf9pc3vH2mdTG9EeX\niOS/pf4EsDK4vxdoAHQcXhFFwiHOXDSLd5xwBBve2MuKNVv5/cZ27vzVeu56yHHyMTO4cPE8bEEr\nIW29i0xZ3kRPMGVmn8YfhkkBs4AYsAv4nHOuY6zHJpOpdCSidUG+2jt7eWTl6/zi6Vfp6va/vNTS\nVMNZJ83mkrOP4sjZ00qcUESKZNQttwmVupl9EPgfwMXAEmC3c26tmV0DzHPOfW6sx7e378/7xePx\nJtrby+8r9qXIlU6n2b
xtH4//YRtrN3bQc9A/UuaI1jpOObaNU45t4+zT5tG5p2dSc+WiXH+PUL7Z\nlGt8qjFXPN40aqlPZEfpu4GvAu9xznUBKzJmLwe+m+9zy/h4nsex85o5dl4zydQAz23sYOVLb7Lu\n1T38etUWfr1qCw3/tY5FR7Zy6rFtnHj0DBrroqWOLSJFkO+O0mbgRuBC59yeYNp9wJedc68Ay4B1\nhQopuYuEQyw9fiZLj59JIjmAe6OTtZs6eOHVPax8eRcrX95FKFgJnHpsG6ccO4NZ0+t1FI1Ilch3\nS/3DQBtwt5kNTrsD+KmZHQC6gU9OPJ5MRDQS4sSjZ3Di0TNoa2tk7Us7Wbupgz9s6mDjlr1s2LKX\nu3+ziZmtdZxyTBsL5zdz7LwWmhtipY4uInnKq9Sdc7cCt2aZ9YOJxZFi8TyPeTMbmTezkfedfST7\nevp5fvNu/rCpg3Wv7eHh1Vt4ePUWAGa21HHsvGaOm+eX/OwZ9TqiRqRC6ODmKWpaQ4xzTp7NOSfP\nJpFM8cr2fWzc2sXGrV1s2tbF0+t28vS6nQA01EY4Zm5Q8nObOWr2NGJRHbUkUo5U6kI0Esbe1oq9\nrRWAgXSa7R09bNraxcate9m4tYvnN+/m+c3+l4XDIY8jZzX5O2fntnD0nGm0NMY0Li9SBlTq8hYh\nz2NevJF58UaWnTYXgL3dfWza2sWGrXvZtLWLV3fsZ/P2fTyEP2RTXxNhTlsDc9oaaGuu5fgFrcxt\na6CuRh8xkcmk/+MkJy2NNSw5fiZLjp8JQF9/ild27GPT1r28saub7R09vLJ9H5u2dQ173PRpNcxt\na2RuUPhz4w3MmdGgC4GIFIlKXfJSEwvz9gWtvH1B69C0RHKA7R09vLZzHzt2H2BbezfbOnp44ZXd\nvPDK7mGPb26McURLHfNmTaOxJky8pY625lriLXW0NNbogtwieVKpS8FEIyEWzGpiwaymYdN7DibY\n1t7Dto4etrf3sH13D7s6e9m4tYsNW7ve8jzhkEdbcy1tLXXEg9vBwm9rrqWxLqrxe5FRqNSl6Bpq\noyyc38LC+S3DpidTA6QjYTa8spv2rl469h6ko6uX9uD2xVf3ZH2+2liYtuY64i21tDXX0dZSS3ND\njJbGGqY31dDcWKOrRcmUpVKXkomEQ8TbGomOcv6hg/1JOvYeHCr8YcXf1cvW9u5Rn7uhNkJjfYym\nuiiNdVGa6qM01kdpqovRWBfcr48G82PU1YS19S9VQaUuZas2Fhn6wtRI6XSa7t4EHV0H6eg6SFd3\nH53dfezd38fe7n66evrp7k3Q3tnLQA4nrQuHvEPlXxelsT5GfHo9EXjLCqEpWCFEdYZRKUMqdalI\nnufRVB+jqT7GUWOcYnggnaa3L0n3gQT7exP+7QG/8Ad/7u71p+3vTbBnXx9b23M7m2VNNExNLExD\nbYTaWISG2gg10TCxaIim+hj1NRFqYmFikRDRiD89FgkTjYaCaSEaaqO0aLhICkilLlUt5Hk01EZp\nqI1yRI6PSaYG6OlNEK2N8ca2vUMrgP0H+g+tBIKf+/pT7D+QoH1vL8lUfmeS9oBYLExtsJKIRcLU\nREPEov4KIRwOEfKgriZCLBqmtbmOZCJJTTRMJOyvHKLBbWTo1iMaCQe3/vxIJGOZcEhHGFUplbrI\nCJFwiObGGuLxJuojuRVfOp2mPzlAfyJFX3+KfQcSHOxP0tef8qcnU/QnBkgE9xPJAbp6+tnddZBk\naoC+/hR9iRQHg5VEfyJFamBiF7A5nHDIIxIJEfb82wMHk0QjIVoaY4Q8b+i7BLGIv0KIhEKEwx7h\njJVBXU2EUMgj7Hk0NtbQdzDh/xzyCHkeXsgj5EE45K+YQiHv0D/P/xcOeXihQ8sMjpbFoqFhy4WG\nnhNI+3+FRcIhwiGPdNo/pDYW9VeCkZA3NOyW9ELs3dtL98EEtbEI4ZC/oksmB8Dzh/lS
A2lIpwmH\nQ3je4BUoPP++Bx7+RA//r8TB3S/+fA+C149EPLwg78jPR5pDucNFXKGq1EUKwPM8fzgmGqapHtpa\n6ib8nKmBgaEVQWog7a8wEinqG2p5s2M//f0pEil/fnLoNk0imSKRSpNMDpBIZc7zbxOpgaF5g8+d\nTKWpiYboSwyw/0CCdDrNwf5UkKO4K5dq5QGjvXOxSIi/u2IJRx/x1v1FE6VSFylT4VCIupoQdTXD\np8fjTbQ1Tt5FTpIpv/hTqTTJgQFSwTBTauDQSmFgIE1zcz0de7pJpdKkBtKk02kG0mkGBghug3/B\n/VRwP51m6DmGT/fnDT0m47kGt5gHs4VDHpFwiN7+JOngOQaHl6LRCD0H+tm+u4d58QbSadjX08+0\nhhjpdJq+xMDQsqnUAHDor4XBLex0GtKkCf4byja4DPh/4SWDx/sb/n5OPC/I6y/veR7pdJppDSN+\nsQWiUheRMUXCISJh4DDrkXi8ieba8jsiqBovZzcW7XIXEakiBd9SN7ObgDPx/0r5gnNuVaFfQ0RE\nsivolrqZnQcc55w7C7gS+L+FfH4RERlboYdfLgD+C8A59zLQamajfzNEREQKqtDDL7OANRk/twfT\n9mVbuLW1nsgEvmodjzcdfqESUK7xKddcUL7ZlGt8plKuYh/9MuYR9p2dB/J+4qm2R3uilGv8yjWb\nco1PNeYaa2VQ6OGX7fhb5oPmADsK/BoiIjKKQpf6r4HLAMzsdGC7c678VpEiIlXKG/w2VKGY2f8C\n/ggYAK5yzv2hoC8gIiKjKnipi4hI6egbpSIiVUSlLiJSRVTqIiJVRKUuIlJFVOoiIlVEpS4iUkUq\n8iIZpT69r5ndAJyL//59E/gAsBjYHSxyo3PuF2Z2OfBF/GP2b3XO3V7ETMuAe4AXg0kvADcAdwFh\n/G/2XuGc65vMXEG2K4ErMiYtAVYDDUBPMO1vnXNrzOzLwIfwf7fXOed+WYQ8JwL3Azc5575jZvPJ\n8X0ysyhwJ7AASAGfdM69UsRcd+BfniIBfMw5t9PMEsBTGQ+9AH8DbbJy3UmOn/dJfr/uAeLB7OnA\n74B/xv9/YfCcVO3OuQ+ZWTPwE6AZ6AY+6pzbU6BcI/thFZP4+aq4Us88va+ZvR34PnDWJL7++cCJ\nwevPAJ4DHgW+4pz7ecZyDcC1wDuAfmCVmf2sUB+cUTzunLssI8MdwC3OuXvM7J+BT5nZDyc7V7DS\nuD3IdB7wZ8Ai/A/suoy8RwF/jv/7bAaeNLOHnHOpQmUJfi83AysyJl9Pju8T8H5gr3PucjO7GP9/\n2g8XKdc38P9nv9vMrgK+BPwd0OWcWzbi8R+bxFyQ4+edSXy/nHMfypj/feC2Q7OGv1/4ZfqYc+5G\nM/s08PfBv4nmytYPK5jEz1fgM7sdAAAEFUlEQVQlDr+U+vS+T+BvSQLsxd/azHaqyTOAVc65Ludc\nL/6W1TsnJ+KQZcDy4P4DwIVlkOta4OujzDsf+JVzrt851w68DpxQ4NfvAy7BP0/RoGXk/j5dAPws\nWPYRCvfeZcv1WeC+4H47MGOMx09mrmzK4f0CwMwMaHHOrRzj8Zm5Bn/nhZCtH5YxiZ+vSiz1Wfgf\n8EGDp/edFM65lHNucMjgSuCX+H8mfc7MHjWz/zCztiw5dwGzixzvBDNbbma/NbOLgAbnXN+I1y9F\nLgDMbCmwxTm3M5h0vZk9YWb/ZmZ1k5HNOZcM/ifKNJ73aWi6c24ASJtZrBi5nHM9zrmUmYWBq/CH\nCwBqzewnZvaUmX0pmDZpuQK5ft4nOxfAF/C34gfNMrN7zezpYMiDEXkL9jkbpR8m9fNViaU+0pin\n9y0WM/sg/i/tc/jjZdc4594FrAX+MctDip1zI3Ad8EHg4/jDHZnDa6O9/mS+f3+BP14I8K/Al51z\nQ+cJyrJ8KX63432fipoxKPS7gEedc4NDDVcDnwYu
Bi43syWTnGsin/div18x4Bzn3G+CSbuBfwA+\ngr/v6+tmNrLAC55pRD/k8loFe78qsdRLfnpfM3s38FXgvcGfTyucc2uD2cuBk7LknMvh/4TNm3Nu\nm3Pup865tHNuM7ATf2iqbsTrT2quEZYBTwd5fxbkBP9P0kl/zzJ0j+N9Gpoe7NTynHP9Rcx2B7DR\nOXfd4ATn3Pecc93BFuEKRrx3xc41zs/7ZL9f5wFDwy7Ouf3OuTuccwnnXAf+DvrjR+Qt6OdsZD8w\nyZ+vSiz1kp7eN9hrfiPwvsGdi2Z2n5kdHSyyDFgHPAssNbMWM2vEHxt7soi5Ljezq4P7s4Aj8Avh\n0mCRS4EHJztXRr45QLdzrt/MPDN7xMxagtnL8N+zR4E/NrNYsPxc4KViZ8Mfu8z1ffo1h8ZM3w/8\nhiIJhgr6nXNfy5hmwdCLZ2aRINeLk5xrPJ/3ScsVWAoMnRnWzM43s28H9xuAU4ENI3IN/s4nLFs/\nMMmfr4o8S6OV8PS+wZ7yf8T/YAy6A//PrAP4h0d90jm3y8wuA76Mf3jezc65HxcxVxP+mGsLEMMf\ninkO+CFQi7/T8ZPOucRk5srItxj4hnPuvcHPf4Z/tEEPsA240jl3wMw+D1weZPufGUMOhczxLeBI\n/MMEtwWvdyc5vE/BcMhtwHH4O+s+4ZzbUqRcM4GDHLoc5EvOuc+a2f8G3oX/+V/unPunSc51M3AN\nOXzeJznXn+J/7n/rnPtpsFwkeH3DP6Dhu865O4Ii/RH+zue9+IeLdhUgV7Z++HiQYVI+XxVZ6iIi\nkl0lDr+IiMgoVOoiIlVEpS4iUkVU6iIiVUSlLiJSRVTqIiJVRKUuIlJF/j9M7v0+O93/0QAAAABJ\nRU5ErkJggg==\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "metadata": {
        "id": "tudLOfJazx4X",
        "colab_type": "code",
        "outputId": "9c9cbd93-433a-48ff-e049-22ccb0cdd861",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "cell_type": "code",
      "source": [
        "# Found standard deviations (ground truth is 10 and 1):\n",
        "# std = variance ** 0.5 = exp(log_var) ** 0.5\n",
        "std_1, std_2 = (torch.exp(log_var) ** 0.5 for log_var in (log_var_a, log_var_b))\n",
        "print([std_1.item(), std_2.item()])"
      ],
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[8.587499618530273, 0.9198104739189148]\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}